# Copyright (c) Facebook, Inc. and its affiliates.
import copy
import logging

import cv2
import numpy as np
import torch
from detectron2.config import configurable
from detectron2.data import MetadataCatalog
from detectron2.data import detection_utils as utils
from detectron2.data import transforms as T
from detectron2.structures import BitMasks, Boxes, Instances
from torch.nn import functional as F

from maskdino.data.detection_utils import read_dianjiao_image
from maskdino.data.transforms.color_augmentation import ColorAugSSDTransform
from .mask_former_semantic_dataset_mapper import MaskFormerSemanticDatasetMapper

__all__ = ["DeepLabSemanticDatasetMapper"]


class DeepLabSemanticDatasetMapper(MaskFormerSemanticDatasetMapper):
    """
    A callable which takes a dataset dict in Detectron2 Dataset format,
    and maps it into a format used for per-pixel semantic segmentation training.

    Unlike its parent ``MaskFormerSemanticDatasetMapper``, this mapper only
    emits the dense ``sem_seg`` target; it does not build per-category
    ``instances``.

    The callable currently does the following:

    1. Read the image from "file_name" (via ``read_dianjiao_image``)
    2. Read the semantic label from "sem_seg_file_name"
    3. Apply the transform generators in ``self.tfm_gens`` to image and label
    4. Convert image and label to Tensors and pad them when
       ``self.size_divisibility > 0``

    Configuration attributes read here (``is_train``, ``img_format``,
    ``tfm_gens``, ``size_divisibility``, ``ignore_label``) are presumably set
    by the parent's ``__init__`` — confirm against the base class.
    """

    def _read_sem_seg(self, sem_seg_file_name, binary_mask):
        """
        Read a semantic-segmentation label map as a float64 numpy array.

        Args:
            sem_seg_file_name (str): path of the label image.
            binary_mask (bool): if True, the file is read as a grayscale mask,
                binarized (pixels > 125 -> 1, others -> 0) and rotated 90
                degrees counter-clockwise when it is taller than wide.
                Otherwise it is read with ``read_dianjiao_image``.

        Returns:
            np.ndarray: label map of dtype float64. PyTorch transformations are
            not implemented for uint16, hence the conversion to double.

        Raises:
            FileNotFoundError: if the binary mask cannot be read by OpenCV.
        """
        if not binary_mask:
            return read_dianjiao_image(sem_seg_file_name).astype("double")

        sem_seg_gt = cv2.imread(sem_seg_file_name, cv2.IMREAD_GRAYSCALE)
        if sem_seg_gt is None:
            # cv2.imread signals failure by returning None rather than raising;
            # fail here with a clear message instead of inside cv2.threshold.
            raise FileNotFoundError(
                "Cannot read semantic segmentation mask {}.".format(sem_seg_file_name)
            )
        # Pixels > 125 become 255, all others 0.
        sem_seg_gt = cv2.threshold(sem_seg_gt, 125, 255, cv2.THRESH_BINARY)[1]
        h, w = sem_seg_gt.shape
        if h > w:
            # NOTE(review): presumably matches an orientation normalization that
            # read_dianjiao_image applies to the image — confirm.
            sem_seg_gt = cv2.rotate(sem_seg_gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
        sem_seg_gt[sem_seg_gt == 255] = 1
        return sem_seg_gt.astype("double")

    def __call__(self, dataset_dict):
        """
        Args:
            dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format.
                Must contain "file_name" and "sem_seg_file_name" and must not
                contain "annotations". If the key "0_255_type" is present (its
                value is not inspected), the label is read as a 0/255 binary mask.

        Returns:
            dict: a deep copy of ``dataset_dict`` without "sem_seg_file_name",
            with "image" (Tensor, C x H x W) and "sem_seg" (int64 Tensor) added.

        Raises:
            AssertionError: if the mapper is not in training mode.
            ValueError: if "sem_seg_file_name" is missing or "annotations" is present.
        """
        assert self.is_train, "{} should only be used for training!".format(
            type(self).__name__
        )

        # Validate before doing any copying or image I/O.
        if "sem_seg_file_name" not in dataset_dict:
            raise ValueError(
                "Cannot find 'sem_seg_file_name' for semantic segmentation dataset {}.".format(
                    dataset_dict["file_name"]
                )
            )
        if "annotations" in dataset_dict:
            raise ValueError("Semantic segmentation dataset should not have 'annotations'.")

        dataset_dict = copy.deepcopy(dataset_dict)  # it will be modified by code below

        image = read_dianjiao_image(dataset_dict["file_name"], format=self.img_format)
        utils.check_image_size(dataset_dict, image)

        sem_seg_gt = self._read_sem_seg(
            dataset_dict.pop("sem_seg_file_name"), "0_255_type" in dataset_dict
        )

        aug_input = T.AugInput(image, sem_seg=sem_seg_gt)
        aug_input, _ = T.apply_transform_gens(self.tfm_gens, aug_input)
        image = aug_input.image
        sem_seg_gt = aug_input.sem_seg

        # HWC numpy image -> CHW tensor; label -> integer tensor.
        image = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1)))
        sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long"))

        if self.size_divisibility > 0:
            # NOTE: despite its name, size_divisibility is used as a fixed target
            # canvas: bottom/right are padded up to exactly that size. If the
            # transformed image is larger, the pads are negative and F.pad
            # crops instead.
            image_h, image_w = image.shape[-2], image.shape[-1]
            padding_size = [
                0,
                self.size_divisibility - image_w,
                0,
                self.size_divisibility - image_h,
            ]
            image = F.pad(image, padding_size, value=128).contiguous()
            sem_seg_gt = F.pad(sem_seg_gt, padding_size, value=self.ignore_label).contiguous()

        # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory,
        # but not efficient on large generic data structures due to the use of pickle & mp.Queue.
        # Therefore it's important to use torch.Tensor.
        dataset_dict["image"] = image
        # .long() guarantees int64: numpy's "long" is int32 on some platforms.
        dataset_dict["sem_seg"] = sem_seg_gt.long()

        return dataset_dict
